#!/usr/bin/env python

# pip-audit-bulk: run osv-scanner in bulk (one scan per formula), saving (reduced) reports in JSON

import collections
import json
import os
import shutil
import subprocess
import sys
from pathlib import Path
from tempfile import NamedTemporaryFile

# This entire script is path-sensitive: its input and output files live in
# the `site` directory next to this file.
_HERE = Path(__file__).resolve().parent
_SITE = _HERE / "site"

# Explicit check instead of `assert`, which is stripped under `python -O`.
if not _SITE.is_dir():
    sys.exit("missing site dir (wrong dir?)")

# Input: formula-name -> [requirements]
_REQUIREMENTS_JSON = _SITE / "formula-requirements.json"

# JSON is UTF-8 by specification; don't depend on the locale's encoding.
_REQUIREMENTS = json.loads(_REQUIREMENTS_JSON.read_text(encoding="utf-8"))

# Output: the audit report, written after every formula has been scanned.
_AUDITS_JSON = _SITE / "formula-audits.json"

audits = {
    "vulnerable": {},  # formula-name -> vulnerable packages from the scanner
    "failures": [],  # formulae whose scan failed
    "skipped": [],  # NOTE: currently unused
}

# Only write a step summary when running in CI *and* a summary path is
# actually provided; a CI system that sets `CI` without
# `GITHUB_STEP_SUMMARY` (or sets it to an empty string) gets no summary
# instead of a crash.
_summary_path = os.environ.get("GITHUB_STEP_SUMMARY") if "CI" in os.environ else None
_SUMMARY = Path(_summary_path) if _summary_path else None

# {(dependency, CVE): [list of formula]}
_SUMMARY_RESULTS = collections.defaultdict(list)

# Fail early (and under `-O` too) if the scanner is not on PATH.
if not shutil.which("osv-scanner"):
    sys.exit("missing osv-scanner")

# Scan every formula's requirements with osv-scanner, recording vulnerable
# packages in `audits["vulnerable"]`, failed scans in `audits["failures"]`,
# and (dependency, vulnerability) -> formulae in `_SUMMARY_RESULTS`.
for formula, requirements in _REQUIREMENTS.items():
    # osv-scanner wants a requirements.txt file, so we make a temporary one.
    # The scan runs inside the `with` block because the file is removed as
    # soon as the block exits.
    with NamedTemporaryFile(mode="w") as req_file:
        req_file.write("\n".join(requirements))
        # Flush so the scanner subprocess sees the complete file on disk.
        req_file.flush()

        # The exit status is deliberately not checked; success is judged
        # from stdout below.
        # NOTE(review): the config path is relative to the current working
        # directory, not to this script's directory -- confirm intended.
        result = subprocess.run(
            [
                "osv-scanner",
                "--lockfile",
                f"requirements.txt:{req_file.name}",
                "--format=json",
                "--config=osv-scanner-config.toml",
            ],
            capture_output=True,
            text=True,
        )

    # If we don't see anything on stdout, osv-scanner probably failed.
    # Log it as an error.
    if not result.stdout:
        print(f"\t>:( osv-scanner failed: {result.stderr}", file=sys.stderr)
        audits["failures"].append(formula)
        continue

    # Output that isn't valid JSON is a failure for this formula only; it
    # must not abort the whole bulk run.
    try:
        audit = json.loads(result.stdout)
    except json.JSONDecodeError as exc:
        print(
            f"\t>:( osv-scanner produced unparseable output: {exc}",
            file=sys.stderr,
        )
        audits["failures"].append(formula)
        continue

    results = audit["results"]

    if not results:
        print(f"\t:-) no vulnerabilities found in {formula}", file=sys.stderr)
        continue

    # If we're left with anything, dump it as a result. A single lockfile is
    # expected to yield a single result entry, but collect packages from
    # every entry so nothing is dropped if that expectation ever breaks.
    vulnerable_packages = [
        package for entry in results for package in entry["packages"]
    ]

    for package in vulnerable_packages:
        for group in package["groups"]:
            # Key on the first id listed for the group.
            key = (package["package"]["name"], group["ids"][0])
            _SUMMARY_RESULTS[key].append(formula)

    vuln_count = sum(len(package["groups"]) for package in vulnerable_packages)
    print(f"\t:-( found {vuln_count} vulnerabilities in {formula}", file=sys.stderr)
    audits["vulnerable"][formula] = vulnerable_packages


# Persist the collected audit data as pretty-printed JSON.
_AUDITS_JSON.write_text(json.dumps(audits, indent=2))

# Markdown preamble and table header for the step summary.
_summary_header = """Identified the following vulnerabilities:

| Dependency | Vulnerability | Formula |
| ---------- | ------------- | ------- |
"""

# One table row per (dependency, vulnerability) pair, in sorted key order,
# listing every formula recorded for that pair.
_summary_rows = []
for (package_name, vuln_id), affected in sorted(_SUMMARY_RESULTS.items()):
    affected_md = ", ".join(f"`{name}`" for name in affected)
    _summary_rows.append(
        f"| `{package_name}` | [`{vuln_id}`](https://osv.dev/vulnerability/{vuln_id}) | {affected_md} |\n"
    )

render = _summary_header + "".join(_summary_rows)

# Append (never overwrite) the rendered table to the summary file, if one
# was configured.
if _SUMMARY is not None:
    with _SUMMARY.open(mode="a") as summary_file:
        print(render, file=summary_file)
